import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import warnings
warnings.filterwarnings("ignore")
from multiprocessing import cpu_count, Pool
# Load object metadata (`tr`: object_id, target, hostgal_photoz, ...) and the
# observation log (`tr_log`: object_id, mjd, passband, flux, ...).
# NOTE(review): read_pickle can execute arbitrary code on load; only use
# pickles produced by a trusted preprocessing step.
tr = pd.read_pickle('../data/train.pkl')
tr_log = pd.read_pickle('../data/train_log.pkl')

# Notebook-style inspection: these expressions only render output when they
# are the last line of an interactive cell; in a plain script they are no-ops.
tr.head()
tr.tail()
tr_log.head(20)

# Light curve of one example object (object_id 615).
# Explicit copy: we add a column below, and writing into a boolean-indexed
# slice of tr_log is the chained-assignment pattern pandas warns about.
df = tr_log[tr_log.object_id == 615].copy()
# Bucket observations by the integer part of the MJD timestamp.
df['date'] = df.mjd.astype(int)
# Rows = date, columns = passband. pivot_table's default aggfunc is 'mean',
# so repeated observations of a passband on the same date are averaged.
# Built once and reused for both the plot and the tabular preview.
flux_by_band = pd.pivot_table(df, index=['date'], columns=['passband'], values=['flux'])
flux_by_band.plot(marker="o", legend=True)
# Place the legend outside the axes so it does not cover the curves.
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
flux_by_band.reset_index().head()
def plt_obj(oid=None, save=False, path=None, norm=False):
    """Plot the per-passband light curve of a single object.

    Observations come from the module-level ``tr_log`` frame. The plot title
    is annotated with the object's ``target`` and ``hostgal_photoz`` taken from
    the module-level ``tr`` frame.

    Parameters
    ----------
    oid : object_id to plot. If None, one is drawn at random from
        ``tr.object_id``.
    save : if truthy and ``path`` is given, also write the figure with
        ``plt.savefig(path)``.
    path : output file for the saved figure. Ignored unless ``save`` is set.
    norm : if True, divide every flux value by the object's maximum flux
        before plotting.

    Returns
    -------
    None. The figure is left open on the current pyplot state.

    Raises
    ------
    IndexError
        If ``oid`` has no row in ``tr`` (from ``.values[0]`` on an empty
        selection).
    """
    if oid is None:
        oid = np.random.choice(tr.object_id)

    # Explicit copy: the column writes below must not go through a filtered
    # slice of tr_log. Doing so is the chained-assignment pattern, whose effect
    # depends on the pandas version and which is hidden by the file's global
    # warning filter.
    df = tr_log[tr_log.object_id == oid].copy()
    if norm:
        # NOTE(review): scales by the signed maximum. If all fluxes are <= 0,
        # this flips the curve or divides by zero.
        df['flux'] = df['flux'] / df['flux'].max()

    # Compute the mask once. Look each column up separately so each value keeps
    # its own dtype; a whole-row lookup could upcast the integer target to float.
    is_obj = tr.object_id == oid
    target = tr.loc[is_obj, 'target'].values[0]
    photoz = tr.loc[is_obj, 'hostgal_photoz'].values[0]

    # Bucket observations by integer MJD. pivot_table averages (its default
    # aggfunc) repeated observations of a passband within the same day.
    df['date'] = df.mjd.astype(int)
    pd.pivot_table(df, index=['date'], columns=['passband'], values=['flux']).plot(marker="o", legend=True)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.title(f'oid:{oid} target:{target} photoz:{photoz}')
    if save and path is not None:
        plt.savefig(path)
    return
# Spot-check plt_obj on the example object, raw and max-normalised.
plt_obj(615)
plt_obj(615, norm=True)

# Target classes to browse, then a random selection of object ids per class.
classes = [6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95]
n_per_class = 20
li = []
for c in classes:
    ids_in_class = tr.loc[tr.target == c, 'object_id']
    # sample(n) without replacement raises ValueError when n exceeds the group
    # size, so cap it. Pass random_state=... for a reproducible selection.
    li += ids_in_class.sample(min(n_per_class, len(ids_in_class))).tolist()

# One light-curve figure per sampled object.
for i in li:
    plt_obj(i)
def multi(args):
    """Call ``plt_obj`` with the tuple ``args`` unpacked as positional arguments.

    It takes a single argument, which is the shape ``multiprocessing.Pool.map``
    passes to its worker function, so it can also be used there.
    """
    return plt_obj(*args)


# NOTE(review): `multi` and `argss` were used here but defined nowhere in this
# file, so the script ended with a NameError. As a stand-in, re-plot the sampled
# objects (`li`) with normalised flux. Each tuple follows plt_obj's positional
# order (oid, save, path, norm). For example, use
# (oid, True, f'<out_dir>/{oid}.png', False) to write the figures to disk.
argss = [(oid, False, None, True) for oid in li]
for args in argss:
    multi(args)